<!DOCTYPE html>
<html lang="en">
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content="Transform nn.Linear layers to 8-bit integer layers."/>

    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta name="twitter:title" content="LLM.int8() on GPT-NeoX"/>
    <meta name="twitter:description" content="Transform nn.Linear layers to 8-bit integer layers."/>
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>

    <meta property="og:url" content="https://nn.labml.ai/neox/utils/llm_int8.html"/>
    <meta property="og:title" content="LLM.int8() on GPT-NeoX"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta property="og:site_name" content="LLM.int8() on GPT-NeoX"/>
    <meta property="og:type" content="object"/>
    <meta property="og:title" content="LLM.int8() on GPT-NeoX"/>
    <meta property="og:description" content="Transform nn.Linear layers to 8-bit integer layers."/>

    <title>LLM.int8() on GPT-NeoX</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../../pylit.css?v=1">
    <link rel="canonical" href="https://nn.labml.ai/neox/utils/llm_int8.html"/>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">

    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
    <script>
        window.dataLayer = window.dataLayer || [];

        function gtag() {
            dataLayer.push(arguments);
        }

        gtag('js', new Date());

        gtag('config', 'G-4V3HC8HBLH');
    </script>
</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a class="parent" href="/">home</a>
                <a class="parent" href="../index.html">neox</a>
                <a class="parent" href="index.html">utils</a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
                    <img alt="Github"
                         src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
                         style="max-width:100%;"/></a>
                <a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
                    <img alt="Twitter"
                         src="https://img.shields.io/twitter/follow/labmlai?style=social"
                         style="max-width:100%;"/></a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/utils/llm_int8.py" target="_blank">
                    View code on Github</a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-0'>#</a>
            </div>
<h1>LLM.int8() on GPT-NeoX</h1>
<p>This implements a utility function to transform an <code  class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span></code>
 layer into an LLM.int8() linear layer.</p>
<p><a href="https://arxiv.org/abs/eb2bcaee1d0011edaa66a71c10a887e7">LLM.int8() paper</a>  shows you can use int8 quantization while handling outliers to reduce memory footprint without performance degradation in large language models. They convert weights and inputs to scaled 8-bit integers and does matrix multiplication producing int32 results which is then converted back to float16 and rescaled. They show that in large langauge models, some features can give extreme values (outliers) that dominate the model&#x27;s output. These features get clamped in 8-bit integer space which causes the model performance to degrade. As a solution they pick these outliers (greater than a specified threshold) and compute their multiplications separately in float16 space. Since the percentage of outliers is around 0.01% this doesn&#x27;t increase memory usage, and prevents the model from degrading performance.</p>
<p>The code to transform GPT-NeoX layers is defined in <a href="../model.html#post_load_prepare">model.py</a>; a rough sketch of that replacement follows the examples below.</p>
<p>Here are example uses of GPT-NeoX with int8 quantization.</p>
<ul><li><a href="../samples/llm_int8.html">Generate Text</a> </li>
<li><a href="../evaluation/llm_int8.html">Run Evaluation Tests</a></li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">33</span><span></span></pre></div>
        </div>
    </div>
    <div class='section' id='section-1'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-1'>#</a>
            </div>
            <p>Import the <a href="https://github.com/timdettmers/bitsandbytes"><code  class="highlight"><span></span><span class="n">bitsandbytes</span></code>
</a> package </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">34</span><span class="k">try</span><span class="p">:</span>
<span class="lineno">35</span>    <span class="kn">from</span> <span class="nn">bitsandbytes.nn</span> <span class="kn">import</span> <span class="n">Linear8bitLt</span><span class="p">,</span> <span class="n">Int8Params</span>
<span class="lineno">36</span><span class="k">except</span> <span class="ne">ImportError</span><span class="p">:</span>
<span class="lineno">37</span>    <span class="k">raise</span> <span class="ne">ImportError</span><span class="p">(</span><span class="s1">&#39;&#39;&#39;Please install `bitsandbytes` with `pip install bitsandbytes -U`&#39;&#39;&#39;</span><span class="p">)</span>
<span class="lineno">38</span>
<span class="lineno">39</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">40</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-2'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-2'>#</a>
            </div>
            <h2>Transform an <code  class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span></code>
 layer into an LLM.int8() linear layer</h2>
<ul><li><code  class="highlight"><span></span><span class="n">linear_module</span></code>
  is the <code  class="highlight"><span></span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span></code>
 layer to transform </li>
<li><code  class="highlight"><span></span><span class="n">device</span></code>
  is the device of the model </li>
<li><code  class="highlight"><span></span><span class="n">threshold</span></code>
  is the threshold <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord mathnormal" style="margin-right:0.0037em;">α</span></span></span></span></span> to use for outlier detection</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">43</span><span class="k">def</span> <span class="nf">make_llm_int8_linear</span><span class="p">(</span><span class="n">linear_module</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">6.0</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-3'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-3'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">53</span>    <span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">linear_module</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-4'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-4'>#</a>
            </div>
            <p>Create an empty Linear8bitLt module </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">56</span>    <span class="n">int8_lin</span> <span class="o">=</span> <span class="n">Linear8bitLt</span><span class="p">(</span>
<span class="lineno">57</span>        <span class="n">linear_module</span><span class="o">.</span><span class="n">in_features</span><span class="p">,</span>
<span class="lineno">58</span>        <span class="n">linear_module</span><span class="o">.</span><span class="n">out_features</span><span class="p">,</span>
<span class="lineno">59</span>        <span class="n">linear_module</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">,</span>
<span class="lineno">60</span>        <span class="n">has_fp16_weights</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">61</span>        <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span>
<span class="lineno">62</span>    <span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-5'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-5'>#</a>
            </div>
            <p>Quantize the weights; <code>Int8Params</code> converts them to int8 when the parameter is moved to the device. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">65</span>    <span class="n">int8_lin</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="s1">&#39;weight&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">Int8Params</span><span class="p">(</span><span class="n">linear_module</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">cpu</span><span class="p">(),</span>
<span class="lineno">66</span>                                                <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="lineno">67</span>                                                <span class="n">has_fp16_weights</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-6'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-6'>#</a>
            </div>
            <p>Set the bias in float16 space </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">70</span>    <span class="k">if</span> <span class="n">linear_module</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">71</span>        <span class="n">int8_lin</span><span class="o">.</span><span class="n">_parameters</span><span class="p">[</span><span class="s1">&#39;bias&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">linear_module</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">data</span><span class="p">,</span>
<span class="lineno">72</span>                                                    <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-7'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-7'>#</a>
            </div>
            <p> </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">75</span>    <span class="k">return</span> <span class="n">int8_lin</span></pre></div>
        </div>
    </div>
    <div class='footer'>
        <a href="https://labml.ai">labml.ai</a>
    </div>
</div>
<script src=../../interactive.js?v=1"></script>
<script>
    function handleImages() {
        var images = document.querySelectorAll('p>img')

        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        img.onclick = function () {
            console.log('clicked')
            document.body.appendChild(modal)
            modalImage.src = img.src
        }

        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>